{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "395e5929",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "#全局变量\n",
    "hub_token = open('/root/hub_token.txt').read().strip()\n",
    "repo_id = 'lansinuote/diffusion.4.text_to_image'\n",
    "push_to_hub = True\n",
    "checkpoint = 'CompVis/stable-diffusion-v1-4'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "099d4db3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(DDPMScheduler {\n",
       "   \"_class_name\": \"DDPMScheduler\",\n",
       "   \"_diffusers_version\": \"0.12.1\",\n",
       "   \"beta_end\": 0.012,\n",
       "   \"beta_schedule\": \"scaled_linear\",\n",
       "   \"beta_start\": 0.00085,\n",
       "   \"clip_sample\": false,\n",
       "   \"num_train_timesteps\": 1000,\n",
       "   \"prediction_type\": \"epsilon\",\n",
       "   \"set_alpha_to_one\": false,\n",
       "   \"skip_prk_steps\": true,\n",
       "   \"steps_offset\": 1,\n",
       "   \"trained_betas\": null,\n",
       "   \"variance_type\": \"fixed_small\"\n",
       " },\n",
       " CLIPTokenizer(name_or_path='CompVis/stable-diffusion-v1-4', vocab_size=49408, model_max_length=77, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken(\"<|startoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=True), 'pad_token': '<|endoftext|>'}))"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from diffusers import DDPMScheduler\n",
    "from transformers import CLIPTokenizer\n",
    "\n",
    "#加载两个工具类\n",
    "scheduler = DDPMScheduler.from_pretrained(checkpoint, subfolder='scheduler')\n",
    "tokenizer = CLIPTokenizer.from_pretrained(checkpoint, subfolder='tokenizer')\n",
    "\n",
    "scheduler, tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "69854652",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration lambdalabs--pokemon-blip-captions-10e3527a764857bd\n",
      "Found cached dataset parquet (/root/.cache/huggingface/datasets/lambdalabs___parquet/lambdalabs--pokemon-blip-captions-10e3527a764857bd/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['image', 'text'],\n",
      "        num_rows: 833\n",
      "    })\n",
      "}) {'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1280x1280 at 0x7FCFD1469C40>, 'text': 'a drawing of a green pokemon with red eyes'}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at /root/.cache/huggingface/datasets/lambdalabs___parquet/lambdalabs--pokemon-blip-captions-10e3527a764857bd/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-2d5136ee1c738262.arrow\n",
      "Pushing split train to the Hub.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration lansinuote--diffusion.4.text_to_image-ef2bcab260392e0b\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading and preparing dataset None/None to /root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--diffusion.4.text_to_image-ef2bcab260392e0b/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/99.7M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split:   0%|          | 0/833 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset parquet downloaded and prepared to /root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--diffusion.4.text_to_image-ef2bcab260392e0b/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.\n",
      "Dataset({\n",
      "    features: ['image', 'input_ids'],\n",
      "    num_rows: 833\n",
      "}) {'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1280x1280 at 0x7FCFD13F5D90>, 'input_ids': [49406, 320, 3610, 539, 320, 1901, 9528, 593, 736, 3095, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407]}\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "\n",
    "def get_dataset():\n",
    "    #加载数据集\n",
    "    dataset = load_dataset('lambdalabs/pokemon-blip-captions')\n",
    "    \n",
    "    print(dataset, dataset['train'][0])\n",
    "    \n",
    "    #文字编码\n",
    "    def f(data):\n",
    "        #77 = tokenizer.model_max_length\n",
    "        token = tokenizer.batch_encode_plus(data['text'],\n",
    "                                            max_length=77,\n",
    "                                            padding='max_length',\n",
    "                                            truncation=True,\n",
    "                                            return_tensors=None)\n",
    "\n",
    "        data['input_ids'] = token['input_ids']\n",
    "\n",
    "        return data\n",
    "\n",
    "\n",
    "    dataset = dataset.map(f,\n",
    "                          batched=True,\n",
    "                          batch_size=100,\n",
    "                          num_proc=1,\n",
    "                          remove_columns=['text'])\n",
    "\n",
    "    return dataset\n",
    "\n",
    "\n",
    "if push_to_hub:\n",
    "    dataset = get_dataset()\n",
    "    dataset.push_to_hub(repo_id=repo_id, token=hub_token)\n",
    "\n",
    "#直接使用我处理好的数据集\n",
    "dataset = load_dataset(path=repo_id, split='train')\n",
    "\n",
    "print(dataset, dataset[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9d783d32",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset({\n",
      "    features: ['image', 'input_ids'],\n",
      "    num_rows: 833\n",
      "}) {'pixel_values': tensor([[[1., 1., 1.,  ..., 1., 1., 1.],\n",
      "         [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "         [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "         ...,\n",
      "         [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "         [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "         [1., 1., 1.,  ..., 1., 1., 1.]],\n",
      "\n",
      "        [[1., 1., 1.,  ..., 1., 1., 1.],\n",
      "         [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "         [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "         ...,\n",
      "         [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "         [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "         [1., 1., 1.,  ..., 1., 1., 1.]],\n",
      "\n",
      "        [[1., 1., 1.,  ..., 1., 1., 1.],\n",
      "         [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "         [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "         ...,\n",
      "         [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "         [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "         [1., 1., 1.,  ..., 1., 1., 1.]]]), 'input_ids': [49406, 320, 3610, 539, 320, 1901, 9528, 593, 736, 3095, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407]}\n"
     ]
    }
   ],
   "source": [
    "import torchvision\n",
    "\n",
    "#图像增强\n",
    "compose = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.Resize(\n",
    "        512, interpolation=torchvision.transforms.InterpolationMode.BILINEAR),\n",
    "    torchvision.transforms.CenterCrop(512),\n",
    "    torchvision.transforms.RandomHorizontalFlip(),\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    torchvision.transforms.Normalize([0.5], [0.5]),\n",
    "])\n",
    "\n",
    "\n",
    "#应用图像数据增强\n",
    "def f(data):\n",
    "    pixel_values = [compose(i) for i in data['image']]\n",
    "\n",
    "    return {'pixel_values': pixel_values, 'input_ids': data['input_ids']}\n",
    "\n",
    "\n",
    "#因为图像增强在每一个epoch中是动态计算的,所以不能简单地用map处理\n",
    "dataset = dataset.with_transform(f)\n",
    "\n",
    "print(dataset, dataset[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "43037e7c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(833,\n",
       " {'pixel_values': tensor([[[[1., 1., 1.,  ..., 1., 1., 1.],\n",
       "            [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "            [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "            ...,\n",
       "            [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "            [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "            [1., 1., 1.,  ..., 1., 1., 1.]],\n",
       "  \n",
       "           [[1., 1., 1.,  ..., 1., 1., 1.],\n",
       "            [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "            [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "            ...,\n",
       "            [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "            [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "            [1., 1., 1.,  ..., 1., 1., 1.]],\n",
       "  \n",
       "           [[1., 1., 1.,  ..., 1., 1., 1.],\n",
       "            [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "            [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "            ...,\n",
       "            [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "            [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "            [1., 1., 1.,  ..., 1., 1., 1.]]]]),\n",
       "  'input_ids': tensor([[49406,   320,  3610,   539,   320,  7651,  9465,  5050,   320,  4987,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407]])})"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#定义loader\n",
    "def collate_fn(data):\n",
    "    pixel_values = torch.stack([i['pixel_values'] for i in data])\n",
    "    input_ids = torch.LongTensor([i['input_ids'] for i in data])\n",
    "    return {'pixel_values': pixel_values, 'input_ids': input_ids}\n",
    "\n",
    "\n",
    "loader = torch.utils.data.DataLoader(dataset,\n",
    "                                     shuffle=True,\n",
    "                                     collate_fn=collate_fn,\n",
    "                                     batch_size=1)\n",
    "\n",
    "len(loader), next(iter(loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "217c0a9c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "encoder 12306.048\n",
      "vae 8365.3863\n",
      "unet 85952.0964\n"
     ]
    }
   ],
   "source": [
    "from diffusers import AutoencoderKL, UNet2DConditionModel\n",
    "from transformers import CLIPTextModel\n",
    "\n",
    "#加载3个模型\n",
    "encoder = CLIPTextModel.from_pretrained(checkpoint, subfolder='text_encoder')\n",
    "vae = AutoencoderKL.from_pretrained(checkpoint, subfolder='vae')\n",
    "unet = UNet2DConditionModel.from_pretrained(checkpoint, subfolder='unet')\n",
    "\n",
    "vae.requires_grad_(False)\n",
    "encoder.requires_grad_(False)\n",
    "unet.enable_gradient_checkpointing()\n",
    "\n",
    "\n",
    "def print_model_size(name, model):\n",
    "    print(name, sum(i.numel() for i in model.parameters()) / 10000)\n",
    "\n",
    "\n",
    "print_model_size('encoder', encoder)\n",
    "print_model_size('vae', vae)\n",
    "print_model_size('unet', unet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a5f838ac",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(AdamW (\n",
       " Parameter Group 0\n",
       "     amsgrad: False\n",
       "     betas: (0.9, 0.999)\n",
       "     capturable: False\n",
       "     eps: 1e-08\n",
       "     foreach: None\n",
       "     lr: 1e-05\n",
       "     maximize: False\n",
       "     weight_decay: 0.01\n",
       " ),\n",
       " MSELoss())"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "optimizer = torch.optim.AdamW(unet.parameters(),\n",
    "                              lr=1e-5,\n",
    "                              betas=(0.9, 0.999),\n",
    "                              weight_decay=0.01,\n",
    "                              eps=1e-8)\n",
    "\n",
    "criterion = torch.nn.MSELoss()\n",
    "\n",
    "optimizer, criterion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a3f32e9f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.0850, grad_fn=<MseLossBackward0>)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_loss(data):\n",
    "    device = data['input_ids'].device\n",
    "\n",
    "    #文字编码\n",
    "    #[1, 77] -> [1, 77, 768]\n",
    "    out_encoder = encoder(data['input_ids'])[0]\n",
    "\n",
    "    #抽取图像特征图\n",
    "    #[1, 3, 512, 512] -> [1, 4, 64, 64]\n",
    "    out_vae = vae.encode(data['pixel_values']).latent_dist.sample()\n",
    "    #0.18215 = vae.config.scaling_factor\n",
    "    out_vae = out_vae * 0.18215\n",
    "\n",
    "    #随机数,unet的计算目标\n",
    "    noise = torch.randn_like(out_vae)\n",
    "\n",
    "    #往特征图中添加噪声\n",
    "    #1000 = scheduler.num_train_timesteps\n",
    "    #1 = out_vae.shape[0]\n",
    "    noise_step = torch.randint(0, 1000, (1, )).long()\n",
    "    noise_step = noise_step.to(device)\n",
    "    out_vae_noise = scheduler.add_noise(out_vae, noise, noise_step)\n",
    "\n",
    "    #根据文字信息,把特征图中的噪声计算出来\n",
    "    out_unet = unet(out_vae_noise, noise_step, out_encoder).sample\n",
    "\n",
    "    #计算mse loss\n",
    "    #[1, 4, 64, 64],[1, 4, 64, 64]\n",
    "    return criterion(out_unet, noise)\n",
    "\n",
    "\n",
    "get_loss({\n",
    "    'input_ids': torch.ones(1, 77).long(),\n",
    "    'pixel_values': torch.randn(1, 3, 512, 512)\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "84b88599",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Cloning https://huggingface.co/lansinuote/diffusion.4.text_to_image into local empty directory.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 11.485293690289836\n",
      "2 22.174073058733484\n",
      "4 21.349158763390733\n",
      "6 21.735034728772007\n",
      "8 21.165746794169536\n",
      "10 19.808746901675477\n",
      "12 21.422606183303287\n",
      "14 19.779738241253654\n",
      "16 20.875592649244936\n",
      "18 20.723903047357453\n",
      "20 20.24909785448108\n",
      "22 21.206063293473562\n",
      "24 20.186157436284702\n",
      "26 20.302006809739396\n",
      "28 19.989948059926974\n",
      "30 20.634496434227913\n",
      "32 20.226739437202923\n",
      "34 19.87426323920954\n",
      "36 20.088639282184886\n",
      "38 20.3696921497758\n",
      "40 19.518499594909372\n",
      "42 19.35646905520116\n",
      "44 19.544753054447938\n",
      "46 18.9495884028438\n",
      "48 19.509514296558336\n",
      "50 20.509902719059028\n",
      "52 19.261830052666483\n",
      "54 19.733987627958413\n",
      "56 19.839617400284624\n",
      "58 19.995629528842983\n",
      "60 20.015139478753554\n",
      "62 19.416784361950704\n",
      "64 18.512044332295773\n",
      "66 19.14228089810058\n",
      "68 18.59874345047865\n",
      "70 19.604682393954135\n",
      "72 18.631113044335507\n",
      "74 18.933612779015675\n",
      "76 19.089062765357085\n",
      "78 18.24768439900072\n",
      "80 18.213059518297086\n",
      "82 19.6351735486096\n",
      "84 17.98689945116348\n",
      "86 18.266969905249425\n",
      "88 17.443939730350394\n",
      "90 19.13027249123843\n",
      "92 18.595039716441534\n",
      "94 18.144658451201394\n",
      "96 17.69253787994967\n",
      "98 17.593372554067173\n",
      "100 18.93941926685511\n",
      "102 18.434513391999644\n",
      "104 17.44973034906434\n",
      "106 16.89524385180266\n",
      "108 17.700713564569014\n",
      "110 17.172305612359196\n",
      "112 18.20459977351129\n",
      "114 16.766683003355865\n",
      "116 17.56939655293536\n",
      "118 16.291864767161314\n",
      "120 17.255157917024917\n",
      "122 17.37772032187786\n",
      "124 17.324315564343124\n",
      "126 16.778003203260596\n",
      "128 17.46415755758062\n",
      "130 16.670364410601906\n",
      "132 17.570376775198383\n",
      "134 16.4826207121514\n",
      "136 16.441178401655634\n",
      "138 17.320468926394824\n",
      "140 16.65217541824677\n",
      "142 16.829039189331525\n",
      "144 16.47900081503758\n",
      "146 16.629376866912935\n",
      "148 16.64803347547422\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Fetching 16 files:   0%|          | 0/16 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)0b28/vae/config.json:   0%|          | 0.00/551 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Upload file unet/diffusion_pytorch_model.bin:   0%|          | 32.0k/3.20G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Upload file vae/diffusion_pytorch_model.bin:   0%|          | 32.0k/319M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Upload file safety_checker/pytorch_model.bin:   0%|          | 32.0k/1.13G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Upload file text_encoder/pytorch_model.bin:   0%|          | 32.0k/470M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from diffusers import StableDiffusionPipeline\n",
    "from huggingface_hub import Repository, create_repo\n",
    "\n",
    "\n",
    "def train():\n",
    "    device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "    unet.to(device)\n",
    "    encoder.to(device)\n",
    "    vae.to(device)\n",
    "    unet.train()\n",
    "\n",
    "    loss_sum = 0\n",
    "    for epoch in range(150):\n",
    "        for i, data in enumerate(loader):\n",
    "            for k in data.keys():\n",
    "                data[k] = data[k].to(device)\n",
    "\n",
    "            loss = get_loss(data) / 4\n",
    "            loss.backward()\n",
    "            loss_sum += loss.item()\n",
    "\n",
    "            if (epoch * len(loader) + i) % 4 == 0:\n",
    "                torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)\n",
    "                optimizer.step()\n",
    "                optimizer.zero_grad()\n",
    "\n",
    "        if epoch % 2 == 0:\n",
    "            print(epoch, loss_sum)\n",
    "            loss_sum = 0\n",
    "\n",
    "    #保存\n",
    "    StableDiffusionPipeline.from_pretrained(\n",
    "        checkpoint, text_encoder=encoder, vae=vae,\n",
    "        unet=unet).save_pretrained('./save')\n",
    "\n",
    "\n",
    "if push_to_hub:\n",
    "    create_repo(repo_id, exist_ok=True, token=hub_token)\n",
    "    repo = Repository('./save', clone_from=repo_id, token=hub_token)\n",
    "    train()\n",
    "    repo.push_to_hub()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
